import dataset_utils
from tqdm import tqdm
import os
import argparse
from transformers import PreTrainedTokenizerFast


def getargs():
    """Build the command-line parser for this script.

    Registered flags (in this order):
        --outdir     str, no default
        --maxlen     int, default 512
        --modelname  str, no default
        --tasks      one or more values, default empty list
        --nproc      int, default 32
        --maskrate   float, no default
        --seed       int, no default

    Returns:
        The configured ``argparse.ArgumentParser``; the caller is expected
        to invoke ``parse_args`` on it.
    """
    # (flag, keyword arguments for add_argument), kept in registration order
    # so the generated usage/help text is unchanged.
    option_specs = (
        ("--outdir", {"type": str}),
        ("--maxlen", {"type": int, "default": 512, "required": False}),
        ("--modelname", {"type": str}),
        ("--tasks", {"nargs": "+", "default": []}),
        ("--nproc", {"type": int, "default": 32}),
        ("--maskrate", {"type": float}),
        ("--seed", {"type": int}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli


def get_tokenizer_path(modelname):
    """Return the tokenizer directory for a supported model name.

    Both supported names ("codet5-small" and "codet5-large") resolve to the
    same tokenizer path.

    Args:
        modelname: Model identifier, typically the ``--modelname`` CLI value.

    Returns:
        The relative path of the tokenizer directory.

    Raises:
        AssertionError: If ``modelname`` is not one of the supported names.
    """
    # Guard clause: reject anything outside the supported set up front.
    if modelname not in ("codet5-small", "codet5-large"):
        raise AssertionError("modelname not supported")

    return "./artifacts/tokenizer/codet5/tokenizer"


def get_tokenizer(modelname, tokenizer_path):
    """Load the tokenizer for a supported model name.

    Both supported names ("codet5-small" and "codet5-large") are loaded the
    same way, via ``PreTrainedTokenizerFast.from_pretrained``.

    Args:
        modelname: Model identifier; must be one of the supported names.
        tokenizer_path: Location passed straight to ``from_pretrained``.

    Returns:
        The object returned by ``PreTrainedTokenizerFast.from_pretrained``.

    Raises:
        AssertionError: If ``modelname`` is not one of the supported names.
    """
    # Validate first so an unsupported name never triggers a tokenizer load.
    if modelname not in ("codet5-small", "codet5-large"):
        raise AssertionError("model name not supported")

    return PreTrainedTokenizerFast.from_pretrained(tokenizer_path)


if __name__ == "__main__":
    # Script entry point: parse the CLI, load the tokenizer for the chosen
    # model, build one dataset per requested task, then shuffle each dataset
    # and save it under <outdir>/CSN-<task>.
    parser = getargs()
    args = parser.parse_args()

    print("Arguments: ", args)

    # --outdir and --seed have no defaults, and the loop below computes
    # `args.seed + idx` and `os.path.join(args.outdir, ...)`. Without this
    # check a missing flag only surfaces as a TypeError after the tokenizer
    # and all datasets have been built, so fail fast with a usage message.
    # With no tasks the loop never runs, so nothing is required in that case.
    missing = [
        flag
        for flag, value in (("--outdir", args.outdir), ("--seed", args.seed))
        if value is None
    ]
    if args.tasks and missing:
        parser.error(f"missing required argument(s): {', '.join(missing)}")

    tokenizer_path = get_tokenizer_path(args.modelname)
    tokenizer = get_tokenizer(args.modelname, tokenizer_path)
    task_datasets = dataset_utils.get_task_datasets(args.tasks, tokenizer, args.maxlen, args.maskrate)

    # enumerate() has no length, so pass the total explicitly; otherwise the
    # progress bar cannot show a percentage or an ETA.
    for idx, task in tqdm(enumerate(args.tasks), total=len(args.tasks)):
        # Each task is shuffled with its own seed: the base seed offset by
        # the task's position in --tasks.
        task_ds = task_datasets[idx].shuffle(args.seed + idx)
        fout = os.path.join(args.outdir, f"CSN-{task}")
        # No exist_ok: an already existing output directory raises
        # FileExistsError rather than being reused.
        os.makedirs(fout)
        print("Saving data to disk")
        task_ds.save_to_disk(fout)
